Load packages.
# machine learning
library("DALEX")
library("ranger")
# visualization
library("patchwork")
library("ggplot2")
Happiness data.
# read the train and test splits; the first CSV column supplies the row names
csv_paths <- c(
  train = "data/happiness_train.csv",
  test = "data/happiness_test.csv"
)
train <- read.csv(csv_paths[["train"]], row.names = 1)
test <- read.csv(csv_paths[["test"]], row.names = 1)
knitr::kable(rbind(head(test), tail(test)))
| | score | gdp_per_capita | social_support | healthy_life_expectancy | freedom_life_choices | generosity | perceptions_of_corruption |
|---|---|---|---|---|---|---|---|
| Finland | 7.769 | 1.340 | 1.587 | 0.986 | 0.596 | 0.153 | 0.393 |
| Denmark | 7.600 | 1.383 | 1.573 | 0.996 | 0.592 | 0.252 | 0.410 |
| Norway | 7.554 | 1.488 | 1.582 | 1.028 | 0.603 | 0.271 | 0.341 |
| Iceland | 7.494 | 1.380 | 1.624 | 1.026 | 0.591 | 0.354 | 0.118 |
| Netherlands | 7.488 | 1.396 | 1.522 | 0.999 | 0.557 | 0.322 | 0.298 |
| Switzerland | 7.480 | 1.452 | 1.526 | 1.052 | 0.572 | 0.263 | 0.343 |
| Yemen | 3.380 | 0.287 | 1.163 | 0.463 | 0.143 | 0.108 | 0.077 |
| Rwanda | 3.334 | 0.359 | 0.711 | 0.614 | 0.555 | 0.217 | 0.411 |
| Tanzania | 3.231 | 0.476 | 0.885 | 0.499 | 0.417 | 0.276 | 0.147 |
| Afghanistan | 3.203 | 0.350 | 0.517 | 0.361 | 0.000 | 0.158 | 0.025 |
| Central African Republic | 3.083 | 0.026 | 0.000 | 0.105 | 0.225 | 0.235 | 0.035 |
| South Sudan | 2.853 | 0.306 | 0.575 | 0.295 | 0.010 | 0.202 | 0.091 |
Explain a black-box model.
# fit a ranger model that predicts `score` from all other columns of `train`
model_rf <- ranger(score ~ ., data = train)
# create an explainer for the model; the predictors are selected by name
# (every column except `score`) rather than by position, so the call does not
# depend on `score` being the first column of `test`
predictors <- test[, names(test) != "score", drop = FALSE]
explainer_rf <- explain(model_rf,
                        data = predictors,
                        y = test$score)
## Preparation of a new explainer is initiated
## -> model label : ranger ( default )
## -> data : 156 rows 6 cols
## -> target variable : 156 values
## -> predict function : yhat.ranger will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package ranger , ver. 0.12.1 , task regression ( default )
## -> predicted values : numerical, min = 3.29285 , mean = 5.466507 , max = 6.786591
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -2.157877 , mean = -0.0594106 , max = 1.117275
## A new explainer has been created!
# model-level explanations: the model_parts() plot stacked above the
# model_profile() plot, with the lower panel twice as tall as the upper one
parts_plot <- plot(model_parts(explainer_rf), subtitle = "")
profile_plot <- plot(model_profile(explainer_rf), subtitle = "") +
  theme(axis.title.y = element_text(vjust = -40))
parts_plot / profile_plot + plot_layout(heights = c(1, 2))

# instance-level explanations for the first row of the test set,
# laid out the same way (predict_parts() above predict_profile())
obs <- test[1, ]
pp <- predict_parts(explainer_rf, obs)
pp$label <- rownames(obs)  # label the explanation with the row name of `obs`
obs_parts_plot <- plot(pp, subtitle = "")
obs_profile_plot <- plot(predict_profile(explainer_rf, obs), subtitle = "") +
  theme(axis.title.y = element_text(vjust = -50))
obs_parts_plot / obs_profile_plot + plot_layout(heights = c(1, 2))
Just one line of code…
library("modelStudio")
# build the dashboard from the explainer, with the left margin option set to 150
studio_options <- ms_options(margin_left = 150)
modelStudio(explainer_rf, options = studio_options)
Exploring the parameters based on the vignette and documentation.
Observations for local explanations
# pick three countries from the test set by row name and pass them, together
# with their observed scores, as the observations to explain
new_observation <- test[c("Canada", "Chile", "China"), ]
modelStudio(
  explainer_rf,
  new_observation = new_observation,
  new_observation_y = new_observation$score
)
# instead of passing observations explicitly, request 10 of them
# through `new_observation_n`
modelStudio(explainer_rf, new_observation_n = 10)
Smaller studio
# smaller dashboard: facet_dim of c(1, 2), left margin option set to 150
compact_options <- ms_options(margin_left = 150)
modelStudio(explainer_rf, facet_dim = c(1, 2), options = compact_options)
“I have a huge monitor!”
# larger dashboard: facet_dim of c(2, 3), with an explicit widget id
wide_options <- ms_options(margin_left = 150)
modelStudio(
  explainer_rf,
  facet_dim = c(2, 3),
  widget_id = "user",
  options = wide_options
)
Longer computation for more accurate results
# explicit N / N_fi / B / B_fi settings for the dashboard computations
modelStudio(
  explainer_rf,
  N = 300,
  N_fi = 3000,
  B = 15,
  B_fi = 25
)
ms_options
# visual options are collected first, then passed to modelStudio()
custom_options <- ms_options(
  margin_left = 150,
  ms_title = "modelStudio on useR!-21",
  ms_subtitle = "https://tinyurl.com/RML2021",
  line_size = 4,
  cp_point_size = 6,
  positive_color = "#ffa58c",
  negative_color = "#ae2c87"
)
modelStudio(
  explainer_rf,
  max_vars = 4,
  time = 100,
  eda = FALSE,
  options = custom_options
)
ms_update_observations() & ms_update_options()
# build a dashboard once, then derive updated versions from it
ms <- modelStudio(explainer_rf)
# swap in the first three rows of the test set as the observations
ms_updated <- ms_update_observations(ms, explainer_rf, test[1:3, ])
# change the layout options on the already-updated dashboard
ms_updated_again <- ms_update_options(
  ms_updated,
  facet_dim = c(1, 2),
  margin_left = 150
)
ms_updated_again